{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "yxig5CdZuHb9"
      },
      "source": [
        "# CountGD - Multi-Modal Open-World Object Counting\n",
        "\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "9wyM6J2HuHb-"
      },
      "source": [
        "## Setup\n",
        "\n",
        "The following cells set up the runtime environment by performing these steps:\n",
        "\n",
        "- Mount Google Drive\n",
        "- Install dependencies for running the model\n",
        "- Load the model into memory"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "jn061Tl8uHb-"
      },
      "source": [
        "### Mount Google Drive (if running on colab)\n",
        "\n",
        "The following bit of code will mount your Google Drive folder at `/content/drive`, allowing you to process files directly from it as well as store the results alongside it.\n",
        "\n",
        "Once you execute the next cell, you will be asked to grant the notebook access to your Google Drive. Please follow the on-screen instructions to do so.\n",
        "If you are not running this on Colab, the mount step is skipped and you can use the files already available in your environment."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 2,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "collapsed": true,
        "id": "DkSUXqMPuHb-",
        "outputId": "d2aeeee9-533e-4ff8-d443-9536c27cd624"
      },
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Mounted at /content/drive\n",
            "env: RUNNING_IN_COLAB=True\n"
          ]
        }
      ],
      "source": [
        "# Check if running colab\n",
        "import logging\n",
        "from pathlib import Path\n",
        "\n",
        "# Log INFO and above with a timestamp, level and logger name on every record.\n",
        "logging.basicConfig(\n",
        "  level=logging.INFO,\n",
        "  format='%(asctime)s %(levelname)-8s %(name)s %(message)s'\n",
        ")\n",
        "\n",
        "# Detect Google Colab by probing for its runtime package. Only a failed import\n",
        "# means \"not on Colab\"; any other error (or Ctrl-C) must surface, not be hidden.\n",
        "try:\n",
        "    import google.colab  # noqa: F401 -- imported only to detect the runtime\n",
        "    RUNNING_IN_COLAB = True\n",
        "except ImportError:\n",
        "    RUNNING_IN_COLAB = False\n",
        "\n",
        "if RUNNING_IN_COLAB:\n",
        "    # Mount Google Drive at /content/drive so inputs and outputs can live there.\n",
        "    from google.colab import drive\n",
        "    drive.mount('/content/drive')\n",
        "\n",
        "from IPython.core.magic import register_cell_magic\n",
        "from IPython import get_ipython\n",
        "@register_cell_magic\n",
        "def skip_if(line, cell):\n",
        "    \"\"\"Cell magic: run the cell body unless the expression after ``%%skip_if`` is truthy.\n",
        "\n",
        "    ``line`` is the text following the magic name and is evaluated with ``eval``;\n",
        "    ``cell`` is the rest of the cell, executed through ``run_cell`` when not skipped.\n",
        "    \"\"\"\n",
        "    # NOTE: eval() executes arbitrary code -- only use with expressions you wrote yourself.\n",
        "    should_skip = bool(eval(line))\n",
        "    if not should_skip:\n",
        "        get_ipython().run_cell(cell)\n",
        "\n",
        "\n",
        "%env RUNNING_IN_COLAB {RUNNING_IN_COLAB}\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "kas5YtyluHb_"
      },
      "source": [
        "### Install Dependencies\n",
        "\n",
        "The environment will be set up with the code, models and required dependencies."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "982Yiv5tuHb_",
        "outputId": "2f570d1a-c6cc-49c3-c336-1d784d33a169"
      },
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Downloading the repository...\n",
            "Branch 'pr/5' set up to track remote branch 'pr/5' from 'origin'.\n",
            "Requirement already satisfied: pip in /usr/local/lib/python3.11/dist-packages (24.3.1)\n",
            "Requirement already satisfied: setuptools in /usr/local/lib/python3.11/dist-packages (75.8.0)\n",
            "Requirement already satisfied: wheel in /usr/local/lib/python3.11/dist-packages (0.45.1)\n",
            "Looking in indexes: https://pypi.org/simple, https://download.pytorch.org/whl/cu121\n",
            "Requirement already satisfied: scipy in /usr/local/lib/python3.11/dist-packages (from -r requirements.txt (line 1)) (1.13.1)\n",
            "Requirement already satisfied: termcolor in /usr/local/lib/python3.11/dist-packages (from -r requirements.txt (line 2)) (2.5.0)\n",
            "Requirement already satisfied: addict in /usr/local/lib/python3.11/dist-packages (from -r requirements.txt (line 3)) (2.4.0)\n",
            "Requirement already satisfied: yapf==0.40.1 in /usr/local/lib/python3.11/dist-packages (from -r requirements.txt (line 4)) (0.40.1)\n",
            "Requirement already satisfied: timm in /usr/local/lib/python3.11/dist-packages (from -r requirements.txt (line 5)) (1.0.13)\n",
            "Requirement already satisfied: numpy in /usr/local/lib/python3.11/dist-packages (from -r requirements.txt (line 6)) (1.26.4)\n",
            "Requirement already satisfied: opencv-python in /usr/local/lib/python3.11/dist-packages (from -r requirements.txt (line 7)) (4.10.0.84)\n",
            "Requirement already satisfied: pycocotools in /usr/local/lib/python3.11/dist-packages (from -r requirements.txt (line 8)) (2.0.8)\n",
            "Requirement already satisfied: colorlog in /usr/local/lib/python3.11/dist-packages (from -r requirements.txt (line 9)) (6.9.0)\n",
            "Requirement already satisfied: setuptools in /usr/local/lib/python3.11/dist-packages (from -r requirements.txt (line 10)) (75.8.0)\n",
            "Requirement already satisfied: ushlex in /usr/local/lib/python3.11/dist-packages (from -r requirements.txt (line 11)) (0.99.1)\n",
            "Requirement already satisfied: gradio<5,>=4.0.0 in /usr/local/lib/python3.11/dist-packages (from -r requirements.txt (line 12)) (4.44.1)\n",
            "Requirement already satisfied: gradio-image-prompter in /usr/local/lib/python3.11/dist-packages (from -r requirements.txt (line 13)) (0.1.0)\n",
            "Requirement already satisfied: spaces in /usr/local/lib/python3.11/dist-packages (from -r requirements.txt (line 14)) (0.32.0)\n",
            "Collecting filetype (from -r requirements.txt (line 15))\n",
            "  Downloading filetype-1.2.0-py2.py3-none-any.whl.metadata (6.5 kB)\n",
            "Requirement already satisfied: tqdm in /usr/local/lib/python3.11/dist-packages (from -r requirements.txt (line 16)) (4.67.1)\n",
            "Requirement already satisfied: torch in /usr/local/lib/python3.11/dist-packages (from -r requirements.txt (line 18)) (2.5.1+cu121)\n",
            "Requirement already satisfied: torchvision in /usr/local/lib/python3.11/dist-packages (from -r requirements.txt (line 19)) (0.20.1+cu121)\n",
            "Requirement already satisfied: transformers in /usr/local/lib/python3.11/dist-packages (from -r requirements.txt (line 20)) (4.47.1)\n",
            "Requirement already satisfied: importlib-metadata>=6.6.0 in /usr/local/lib/python3.11/dist-packages (from yapf==0.40.1->-r requirements.txt (line 4)) (8.5.0)\n",
            "Requirement already satisfied: platformdirs>=3.5.1 in /usr/local/lib/python3.11/dist-packages (from yapf==0.40.1->-r requirements.txt (line 4)) (4.3.6)\n",
            "Requirement already satisfied: tomli>=2.0.1 in /usr/local/lib/python3.11/dist-packages (from yapf==0.40.1->-r requirements.txt (line 4)) (2.2.1)\n",
            "Requirement already satisfied: pyyaml in /usr/local/lib/python3.11/dist-packages (from timm->-r requirements.txt (line 5)) (6.0.2)\n",
            "Requirement already satisfied: huggingface_hub in /usr/local/lib/python3.11/dist-packages (from timm->-r requirements.txt (line 5)) (0.27.1)\n",
            "Requirement already satisfied: safetensors in /usr/local/lib/python3.11/dist-packages (from timm->-r requirements.txt (line 5)) (0.5.2)\n",
            "Requirement already satisfied: matplotlib>=2.1.0 in /usr/local/lib/python3.11/dist-packages (from pycocotools->-r requirements.txt (line 8)) (3.10.0)\n",
            "Requirement already satisfied: aiofiles<24.0,>=22.0 in /usr/local/lib/python3.11/dist-packages (from gradio<5,>=4.0.0->-r requirements.txt (line 12)) (23.2.1)\n",
            "Requirement already satisfied: anyio<5.0,>=3.0 in /usr/local/lib/python3.11/dist-packages (from gradio<5,>=4.0.0->-r requirements.txt (line 12)) (3.7.1)\n",
            "Requirement already satisfied: fastapi<1.0 in /usr/local/lib/python3.11/dist-packages (from gradio<5,>=4.0.0->-r requirements.txt (line 12)) (0.115.6)\n",
            "Requirement already satisfied: ffmpy in /usr/local/lib/python3.11/dist-packages (from gradio<5,>=4.0.0->-r requirements.txt (line 12)) (0.5.0)\n",
            "Requirement already satisfied: gradio-client==1.3.0 in /usr/local/lib/python3.11/dist-packages (from gradio<5,>=4.0.0->-r requirements.txt (line 12)) (1.3.0)\n",
            "Requirement already satisfied: httpx>=0.24.1 in /usr/local/lib/python3.11/dist-packages (from gradio<5,>=4.0.0->-r requirements.txt (line 12)) (0.28.1)\n",
            "Requirement already satisfied: importlib-resources<7.0,>=1.3 in /usr/local/lib/python3.11/dist-packages (from gradio<5,>=4.0.0->-r requirements.txt (line 12)) (6.5.2)\n",
            "Requirement already satisfied: jinja2<4.0 in /usr/local/lib/python3.11/dist-packages (from gradio<5,>=4.0.0->-r requirements.txt (line 12)) (3.1.5)\n",
            "Requirement already satisfied: markupsafe~=2.0 in /usr/local/lib/python3.11/dist-packages (from gradio<5,>=4.0.0->-r requirements.txt (line 12)) (2.1.5)\n",
            "Requirement already satisfied: orjson~=3.0 in /usr/local/lib/python3.11/dist-packages (from gradio<5,>=4.0.0->-r requirements.txt (line 12)) (3.10.14)\n",
            "Requirement already satisfied: packaging in /usr/local/lib/python3.11/dist-packages (from gradio<5,>=4.0.0->-r requirements.txt (line 12)) (24.2)\n",
            "Requirement already satisfied: pandas<3.0,>=1.0 in /usr/local/lib/python3.11/dist-packages (from gradio<5,>=4.0.0->-r requirements.txt (line 12)) (2.2.2)\n",
            "Requirement already satisfied: pillow<11.0,>=8.0 in /usr/local/lib/python3.11/dist-packages (from gradio<5,>=4.0.0->-r requirements.txt (line 12)) (10.4.0)\n",
            "Requirement already satisfied: pydantic>=2.0 in /usr/local/lib/python3.11/dist-packages (from gradio<5,>=4.0.0->-r requirements.txt (line 12)) (2.10.5)\n",
            "Requirement already satisfied: pydub in /usr/local/lib/python3.11/dist-packages (from gradio<5,>=4.0.0->-r requirements.txt (line 12)) (0.25.1)\n",
            "Requirement already satisfied: python-multipart>=0.0.9 in /usr/local/lib/python3.11/dist-packages (from gradio<5,>=4.0.0->-r requirements.txt (line 12)) (0.0.20)\n",
            "Requirement already satisfied: ruff>=0.2.2 in /usr/local/lib/python3.11/dist-packages (from gradio<5,>=4.0.0->-r requirements.txt (line 12)) (0.9.2)\n",
            "Requirement already satisfied: semantic-version~=2.0 in /usr/local/lib/python3.11/dist-packages (from gradio<5,>=4.0.0->-r requirements.txt (line 12)) (2.10.0)\n",
            "Requirement already satisfied: tomlkit==0.12.0 in /usr/local/lib/python3.11/dist-packages (from gradio<5,>=4.0.0->-r requirements.txt (line 12)) (0.12.0)\n",
            "Requirement already satisfied: typer<1.0,>=0.12 in /usr/local/lib/python3.11/dist-packages (from gradio<5,>=4.0.0->-r requirements.txt (line 12)) (0.15.1)\n",
            "Requirement already satisfied: typing-extensions~=4.0 in /usr/local/lib/python3.11/dist-packages (from gradio<5,>=4.0.0->-r requirements.txt (line 12)) (4.12.2)\n",
            "Requirement already satisfied: urllib3~=2.0 in /usr/local/lib/python3.11/dist-packages (from gradio<5,>=4.0.0->-r requirements.txt (line 12)) (2.3.0)\n",
            "Requirement already satisfied: uvicorn>=0.14.0 in /usr/local/lib/python3.11/dist-packages (from gradio<5,>=4.0.0->-r requirements.txt (line 12)) (0.34.0)\n",
            "Requirement already satisfied: fsspec in /usr/local/lib/python3.11/dist-packages (from gradio-client==1.3.0->gradio<5,>=4.0.0->-r requirements.txt (line 12)) (2024.10.0)\n",
            "Requirement already satisfied: websockets<13.0,>=10.0 in /usr/local/lib/python3.11/dist-packages (from gradio-client==1.3.0->gradio<5,>=4.0.0->-r requirements.txt (line 12)) (12.0)\n",
            "Requirement already satisfied: psutil<6,>=2 in /usr/local/lib/python3.11/dist-packages (from spaces->-r requirements.txt (line 14)) (5.9.5)\n",
            "Requirement already satisfied: requests<3.0,>=2.19 in /usr/local/lib/python3.11/dist-packages (from spaces->-r requirements.txt (line 14)) (2.32.3)\n",
            "Requirement already satisfied: filelock in /usr/local/lib/python3.11/dist-packages (from torch->-r requirements.txt (line 18)) (3.16.1)\n",
            "Requirement already satisfied: networkx in /usr/local/lib/python3.11/dist-packages (from torch->-r requirements.txt (line 18)) (3.4.2)\n",
            "Requirement already satisfied: nvidia-cuda-nvrtc-cu12==12.1.105 in /usr/local/lib/python3.11/dist-packages (from torch->-r requirements.txt (line 18)) (12.1.105)\n",
            "Requirement already satisfied: nvidia-cuda-runtime-cu12==12.1.105 in /usr/local/lib/python3.11/dist-packages (from torch->-r requirements.txt (line 18)) (12.1.105)\n",
            "Requirement already satisfied: nvidia-cuda-cupti-cu12==12.1.105 in /usr/local/lib/python3.11/dist-packages (from torch->-r requirements.txt (line 18)) (12.1.105)\n",
            "Requirement already satisfied: nvidia-cudnn-cu12==9.1.0.70 in /usr/local/lib/python3.11/dist-packages (from torch->-r requirements.txt (line 18)) (9.1.0.70)\n",
            "Requirement already satisfied: nvidia-cublas-cu12==12.1.3.1 in /usr/local/lib/python3.11/dist-packages (from torch->-r requirements.txt (line 18)) (12.1.3.1)\n",
            "Requirement already satisfied: nvidia-cufft-cu12==11.0.2.54 in /usr/local/lib/python3.11/dist-packages (from torch->-r requirements.txt (line 18)) (11.0.2.54)\n",
            "Requirement already satisfied: nvidia-curand-cu12==10.3.2.106 in /usr/local/lib/python3.11/dist-packages (from torch->-r requirements.txt (line 18)) (10.3.2.106)\n",
            "Requirement already satisfied: nvidia-cusolver-cu12==11.4.5.107 in /usr/local/lib/python3.11/dist-packages (from torch->-r requirements.txt (line 18)) (11.4.5.107)\n",
            "Requirement already satisfied: nvidia-cusparse-cu12==12.1.0.106 in /usr/local/lib/python3.11/dist-packages (from torch->-r requirements.txt (line 18)) (12.1.0.106)\n",
            "Requirement already satisfied: nvidia-nccl-cu12==2.21.5 in /usr/local/lib/python3.11/dist-packages (from torch->-r requirements.txt (line 18)) (2.21.5)\n",
            "Requirement already satisfied: nvidia-nvtx-cu12==12.1.105 in /usr/local/lib/python3.11/dist-packages (from torch->-r requirements.txt (line 18)) (12.1.105)\n",
            "Requirement already satisfied: triton==3.1.0 in /usr/local/lib/python3.11/dist-packages (from torch->-r requirements.txt (line 18)) (3.1.0)\n",
            "Requirement already satisfied: sympy==1.13.1 in /usr/local/lib/python3.11/dist-packages (from torch->-r requirements.txt (line 18)) (1.13.1)\n",
            "Requirement already satisfied: nvidia-nvjitlink-cu12 in /usr/local/lib/python3.11/dist-packages (from nvidia-cusolver-cu12==11.4.5.107->torch->-r requirements.txt (line 18)) (12.6.85)\n",
            "Requirement already satisfied: mpmath<1.4,>=1.1.0 in /usr/local/lib/python3.11/dist-packages (from sympy==1.13.1->torch->-r requirements.txt (line 18)) (1.3.0)\n",
            "Requirement already satisfied: regex!=2019.12.17 in /usr/local/lib/python3.11/dist-packages (from transformers->-r requirements.txt (line 20)) (2024.11.6)\n",
            "Requirement already satisfied: tokenizers<0.22,>=0.21 in /usr/local/lib/python3.11/dist-packages (from transformers->-r requirements.txt (line 20)) (0.21.0)\n",
            "Requirement already satisfied: idna>=2.8 in /usr/local/lib/python3.11/dist-packages (from anyio<5.0,>=3.0->gradio<5,>=4.0.0->-r requirements.txt (line 12)) (3.10)\n",
            "Requirement already satisfied: sniffio>=1.1 in /usr/local/lib/python3.11/dist-packages (from anyio<5.0,>=3.0->gradio<5,>=4.0.0->-r requirements.txt (line 12)) (1.3.1)\n",
            "Requirement already satisfied: starlette<0.42.0,>=0.40.0 in /usr/local/lib/python3.11/dist-packages (from fastapi<1.0->gradio<5,>=4.0.0->-r requirements.txt (line 12)) (0.41.3)\n",
            "Requirement already satisfied: certifi in /usr/local/lib/python3.11/dist-packages (from httpx>=0.24.1->gradio<5,>=4.0.0->-r requirements.txt (line 12)) (2024.12.14)\n",
            "Requirement already satisfied: httpcore==1.* in /usr/local/lib/python3.11/dist-packages (from httpx>=0.24.1->gradio<5,>=4.0.0->-r requirements.txt (line 12)) (1.0.7)\n",
            "Requirement already satisfied: h11<0.15,>=0.13 in /usr/local/lib/python3.11/dist-packages (from httpcore==1.*->httpx>=0.24.1->gradio<5,>=4.0.0->-r requirements.txt (line 12)) (0.14.0)\n",
            "Requirement already satisfied: zipp>=3.20 in /usr/local/lib/python3.11/dist-packages (from importlib-metadata>=6.6.0->yapf==0.40.1->-r requirements.txt (line 4)) (3.21.0)\n",
            "Requirement already satisfied: contourpy>=1.0.1 in /usr/local/lib/python3.11/dist-packages (from matplotlib>=2.1.0->pycocotools->-r requirements.txt (line 8)) (1.3.1)\n",
            "Requirement already satisfied: cycler>=0.10 in /usr/local/lib/python3.11/dist-packages (from matplotlib>=2.1.0->pycocotools->-r requirements.txt (line 8)) (0.12.1)\n",
            "Requirement already satisfied: fonttools>=4.22.0 in /usr/local/lib/python3.11/dist-packages (from matplotlib>=2.1.0->pycocotools->-r requirements.txt (line 8)) (4.55.3)\n",
            "Requirement already satisfied: kiwisolver>=1.3.1 in /usr/local/lib/python3.11/dist-packages (from matplotlib>=2.1.0->pycocotools->-r requirements.txt (line 8)) (1.4.8)\n",
            "Requirement already satisfied: pyparsing>=2.3.1 in /usr/local/lib/python3.11/dist-packages (from matplotlib>=2.1.0->pycocotools->-r requirements.txt (line 8)) (3.2.1)\n",
            "Requirement already satisfied: python-dateutil>=2.7 in /usr/local/lib/python3.11/dist-packages (from matplotlib>=2.1.0->pycocotools->-r requirements.txt (line 8)) (2.8.2)\n",
            "Requirement already satisfied: pytz>=2020.1 in /usr/local/lib/python3.11/dist-packages (from pandas<3.0,>=1.0->gradio<5,>=4.0.0->-r requirements.txt (line 12)) (2024.2)\n",
            "Requirement already satisfied: tzdata>=2022.7 in /usr/local/lib/python3.11/dist-packages (from pandas<3.0,>=1.0->gradio<5,>=4.0.0->-r requirements.txt (line 12)) (2024.2)\n",
            "Requirement already satisfied: annotated-types>=0.6.0 in /usr/local/lib/python3.11/dist-packages (from pydantic>=2.0->gradio<5,>=4.0.0->-r requirements.txt (line 12)) (0.7.0)\n",
            "Requirement already satisfied: pydantic-core==2.27.2 in /usr/local/lib/python3.11/dist-packages (from pydantic>=2.0->gradio<5,>=4.0.0->-r requirements.txt (line 12)) (2.27.2)\n",
            "Requirement already satisfied: charset-normalizer<4,>=2 in /usr/local/lib/python3.11/dist-packages (from requests<3.0,>=2.19->spaces->-r requirements.txt (line 14)) (3.4.1)\n",
            "Requirement already satisfied: click>=8.0.0 in /usr/local/lib/python3.11/dist-packages (from typer<1.0,>=0.12->gradio<5,>=4.0.0->-r requirements.txt (line 12)) (8.1.8)\n",
            "Requirement already satisfied: shellingham>=1.3.0 in /usr/local/lib/python3.11/dist-packages (from typer<1.0,>=0.12->gradio<5,>=4.0.0->-r requirements.txt (line 12)) (1.5.4)\n",
            "Requirement already satisfied: rich>=10.11.0 in /usr/local/lib/python3.11/dist-packages (from typer<1.0,>=0.12->gradio<5,>=4.0.0->-r requirements.txt (line 12)) (13.9.4)\n",
            "Requirement already satisfied: six>=1.5 in /usr/local/lib/python3.11/dist-packages (from python-dateutil>=2.7->matplotlib>=2.1.0->pycocotools->-r requirements.txt (line 8)) (1.17.0)\n",
            "Requirement already satisfied: markdown-it-py>=2.2.0 in /usr/local/lib/python3.11/dist-packages (from rich>=10.11.0->typer<1.0,>=0.12->gradio<5,>=4.0.0->-r requirements.txt (line 12)) (3.0.0)\n",
            "Requirement already satisfied: pygments<3.0.0,>=2.13.0 in /usr/local/lib/python3.11/dist-packages (from rich>=10.11.0->typer<1.0,>=0.12->gradio<5,>=4.0.0->-r requirements.txt (line 12)) (2.18.0)\n",
            "Requirement already satisfied: mdurl~=0.1 in /usr/local/lib/python3.11/dist-packages (from markdown-it-py>=2.2.0->rich>=10.11.0->typer<1.0,>=0.12->gradio<5,>=4.0.0->-r requirements.txt (line 12)) (0.1.2)\n",
            "Downloading filetype-1.2.0-py2.py3-none-any.whl (19 kB)\n",
            "Installing collected packages: filetype\n",
            "Successfully installed filetype-1.2.0\n",
            "inside get_extensions\n",
            "/usr/local/cuda/\n",
            "running build\n",
            "running build_py\n",
            "copying modules/ms_deform_attn.py -> build/lib.linux-x86_64-cpython-311/modules\n",
            "copying modules/__init__.py -> build/lib.linux-x86_64-cpython-311/modules\n",
            "copying functions/__init__.py -> build/lib.linux-x86_64-cpython-311/functions\n",
            "copying functions/ms_deform_attn_func.py -> build/lib.linux-x86_64-cpython-311/functions\n",
            "running build_ext\n",
            "Processing /content/countgd/models/GroundingDINO/ops\n",
            "  Preparing metadata (setup.py): started\n",
            "  Preparing metadata (setup.py): finished with status 'done'\n",
            "Building wheels for collected packages: MultiScaleDeformableAttention\n",
            "  Building wheel for MultiScaleDeformableAttention (setup.py): started\n",
            "  Building wheel for MultiScaleDeformableAttention (setup.py): finished with status 'done'\n",
            "  Created wheel for MultiScaleDeformableAttention: filename=MultiScaleDeformableAttention-1.0-cp311-cp311-linux_x86_64.whl size=2968787 sha256=98492f743278ba68af0ab7dad6642b06a78f2e0633591b993addc80f682f3157\n",
            "  Stored in directory: /tmp/pip-ephem-wheel-cache-fen8yoft/wheels/64/89/f6/145aead02469873626bfd93ef130da365da640c570c39dd3d5\n",
            "Successfully built MultiScaleDeformableAttention\n",
            "Installing collected packages: MultiScaleDeformableAttention\n",
            "  Attempting uninstall: MultiScaleDeformableAttention\n",
            "    Found existing installation: MultiScaleDeformableAttention 1.0\n",
            "    Uninstalling MultiScaleDeformableAttention-1.0:\n",
            "      Successfully uninstalled MultiScaleDeformableAttention-1.0\n",
            "Successfully installed MultiScaleDeformableAttention-1.0\n",
            "* True check_forward_equal_with_pytorch_double: max_abs_err 8.67e-19 max_rel_err 2.35e-16\n",
            "* True check_forward_equal_with_pytorch_float: max_abs_err 4.66e-10 max_rel_err 1.13e-07\n",
            "* True check_gradient_numerical(D=30)\n",
            "* True check_gradient_numerical(D=32)\n",
            "* True check_gradient_numerical(D=64)\n",
            "* True check_gradient_numerical(D=71)\n"
          ]
        },
        {
          "output_type": "stream",
          "name": "stderr",
          "text": [
            "+ '[' True == True ']'\n",
            "+ echo 'Downloading the repository...'\n",
            "+ '[' '!' -d /content/countgd ']'\n",
            "+ cd /content/countgd\n",
            "+ git fetch origin refs/pr/5:refs/remotes/origin/pr/5\n",
            "From https://huggingface.co/spaces/nikigoli/countgd\n",
            " * [new ref]         refs/pr/5  -> origin/pr/5\n",
            "+ git checkout pr/5\n",
            "Switched to a new branch 'pr/5'\n",
            "+ pip install --upgrade pip setuptools wheel\n",
            "+ pip install -r requirements.txt\n",
            "+ export CUDA_HOME=/usr/local/cuda/\n",
            "+ CUDA_HOME=/usr/local/cuda/\n",
            "+ cd models/GroundingDINO/ops\n",
            "+ python3 setup.py build\n",
            "/usr/local/lib/python3.11/dist-packages/torch/utils/cpp_extension.py:497: UserWarning: Attempted to use ninja as the BuildExtension backend but we could not find ninja.. Falling back to using the slow distutils backend.\n",
            "  warnings.warn(msg.format('we could not find ninja.'))\n",
            "/usr/local/lib/python3.11/dist-packages/torch/utils/cpp_extension.py:416: UserWarning: The detected CUDA version (12.2) has a minor version mismatch with the version that was used to compile PyTorch (12.1). Most likely this shouldn't be a problem.\n",
            "  warnings.warn(CUDA_MISMATCH_WARN.format(cuda_str_version, torch.version.cuda))\n",
            "/usr/local/lib/python3.11/dist-packages/torch/utils/cpp_extension.py:426: UserWarning: There are no x86_64-linux-gnu-g++ version bounds defined for CUDA version 12.2\n",
            "  warnings.warn(f'There are no {compiler_name} version bounds defined for CUDA version {cuda_str_version}')\n",
            "+ pip install .\n",
            "+ python3 test.py\n"
          ]
        }
      ],
      "source": [
        "%%bash\n",
        "\n",
        "# Fetch the CountGD code (on Colab), install its Python dependencies and build\n",
        "# the GroundingDINO deformable-attention extension.\n",
        "set -euxo pipefail\n",
        "\n",
        "# RUNNING_IN_COLAB is exported by the first setup cell; fail with a clear hint\n",
        "# instead of a bare \"unbound variable\" error when that cell has not been run.\n",
        ": \"${RUNNING_IN_COLAB:?Run the first setup cell before installing dependencies}\"\n",
        "\n",
        "if [ \"${RUNNING_IN_COLAB}\" == \"True\" ]; then\n",
        "  echo \"Downloading the repository...\"\n",
        "  if [ ! -d /content/countgd ]; then\n",
        "    git clone \"https://huggingface.co/spaces/nikigoli/countgd\" /content/countgd\n",
        "  fi\n",
        "  cd /content/countgd\n",
        "  # Check out the pull-request ref that this notebook is pinned to.\n",
        "  git fetch origin refs/pr/5:refs/remotes/origin/pr/5\n",
        "  git checkout pr/5\n",
        "else\n",
        "  # TODO check if cwd is the correct git repo\n",
        "  # If users use vscode, then we set the default start directory to root of the repo\n",
        "  echo \"Running in $(pwd)\"\n",
        "fi\n",
        "\n",
        "# TODO check for gcc-11 or above\n",
        "\n",
        "# Install pip packages\n",
        "pip install --upgrade pip setuptools wheel\n",
        "pip install -r requirements.txt\n",
        "\n",
        "# Compile modules\n",
        "# Respect a CUDA_HOME that is already set (e.g. a local toolkit install);\n",
        "# fall back to the conventional location otherwise.\n",
        "export CUDA_HOME=\"${CUDA_HOME:-/usr/local/cuda/}\"\n",
        "cd models/GroundingDINO/ops\n",
        "python3 setup.py build\n",
        "pip install .\n",
        "python3 test.py"
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "%cd {\"/content/countgd\" if RUNNING_IN_COLAB else '.'}"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "58iD_HGnvcRJ",
        "outputId": "fe356a68-dced-4f6f-93cc-d83da2f84e28"
      },
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "/content/countgd\n"
          ]
        }
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "gH7A8zthuHb_"
      },
      "source": [
        "## Inference"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "IspbBV0XuHb_"
      },
      "source": [
        "### Loading the model"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "5nBT_HCUuHb_",
        "outputId": "95ceb6c6-bee8-4921-8bff-d28937045f78"
      },
      "outputs": [
        {
          "output_type": "stream",
          "name": "stderr",
          "text": [
            "Some weights of BertModel were not initialized from the model checkpoint at checkpoints/bert-base-uncased and are newly initialized: ['pooler.dense.bias', 'pooler.dense.weight']\n",
            "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
          ]
        },
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "final text_encoder_type: checkpoints/bert-base-uncased\n",
            "load tokenizer done.\n",
            "final text_encoder_type: checkpoints/bert-base-uncased\n",
            "load tokenizer done.\n"
          ]
        }
      ],
      "source": [
        "import app\n",
        "import importlib\n",
        "importlib.reload(app)\n",
        "from app import (\n",
        "    build_model_and_transforms,\n",
        "    get_device,\n",
        "    get_args_parser,\n",
        "    generate_heatmap,\n",
        "    predict,\n",
        ")\n",
        "# Parse an empty argv so every option of the app's parser takes its default value.\n",
        "args = get_args_parser().parse_args([])\n",
        "device = get_device()\n",
        "model, transform = build_model_and_transforms(args)\n",
        "model = model.to(device)\n",
        "\n",
        "\n",
        "def run(image, text):\n",
        "    \"\"\"Call ``predict`` on ``image`` with the prompt ``text`` using the loaded model.\"\"\"\n",
        "    # The fifth argument is passed as None (presumably \"no visual exemplars\" --\n",
        "    # TODO confirm against app.predict).\n",
        "    return predict(model, transform, image, text, None, device)\n",
        "\n",
        "\n",
        "def get_output(image, boxes):\n",
        "    \"\"\"Return ``(len(boxes), heatmap)`` where the heatmap comes from ``generate_heatmap``.\"\"\"\n",
        "    return len(boxes), generate_heatmap(image, boxes)\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "gfjraK3vuHb_"
      },
      "source": [
        "### Input / Output Utils\n",
        "\n",
        "Helper functions for reading images from a zip file and for writing results to a zip file and a CSV file."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "qg0g5B-fuHb_"
      },
      "outputs": [],
      "source": [
        "import io\n",
        "import csv\n",
        "from pathlib import Path\n",
        "from contextlib import contextmanager\n",
        "import zipfile\n",
        "import filetype\n",
        "from PIL import Image\n",
        "logger = logging.getLogger()\n",
        "\n",
        "def images_from_zipfile(p: Path):\n",
        "    \"\"\"Yield ``(filename, image)`` for every readable image stored in the zip at ``p``.\n",
        "\n",
        "    Directory entries, entries that ``filetype`` does not recognise as images and\n",
        "    entries that fail to decode are logged and skipped, so every yielded item is\n",
        "    a 2-tuple that callers can unpack safely.\n",
        "\n",
        "    Args:\n",
        "        p: path to the zip archive to read.\n",
        "\n",
        "    Yields:\n",
        "        ``(info.filename, image)`` with the image's pixel data already loaded.\n",
        "\n",
        "    Raises:\n",
        "        ValueError: if ``p`` is not a zip file. Because this is a generator, the\n",
        "            error is raised when iteration starts, not when the function is called.\n",
        "    \"\"\"\n",
        "    if not zipfile.is_zipfile(p):\n",
        "        raise ValueError(f'{p} is not a zipfile!')\n",
        "\n",
        "    with zipfile.ZipFile(p, 'r') as zipf:\n",
        "        def process_entry(info: zipfile.ZipInfo):\n",
        "            \"\"\"Return ``(filename, image)`` for an image entry, or ``None`` to skip it.\"\"\"\n",
        "            with zipf.open(info) as f:\n",
        "                if not filetype.is_image(f):\n",
        "                    logger.debug(f'Skipping file - {info.filename} as it is not an image')\n",
        "                    return None\n",
        "                try:\n",
        "                    with Image.open(f) as im:\n",
        "                        # load() reads the pixel data now, before the zip member is closed.\n",
        "                        im.load()\n",
        "                        return (info.filename, im)\n",
        "                except Exception:\n",
        "                    # Best effort: one corrupt entry must not abort the whole archive.\n",
        "                    logger.exception(f'Error reading file {info.filename}')\n",
        "                    return None\n",
        "\n",
        "        file_infos = [info for info in zipf.infolist() if not info.is_dir()]\n",
        "        logger.info(f'Found {len(file_infos)} file(s) in the zip')\n",
        "        for info in file_infos:\n",
        "            entry = process_entry(info)\n",
        "            # Only hand real images to the caller; skipped entries come back as None.\n",
        "            if entry is not None:\n",
        "                yield entry\n",
        "\n",
        "@contextmanager\n",
        "def zipfile_writer(p: Path):\n",
        "    \"\"\"Create a new zip archive at ``p`` and yield a function that adds images to it.\n",
        "\n",
        "    The yielded ``write_output(image, image_filename)`` encodes ``image`` as PNG in\n",
        "    memory and stores the bytes in the archive under ``image_filename``. The\n",
        "    archive is closed when the ``with`` block exits.\n",
        "    \"\"\"\n",
        "    with zipfile.ZipFile(p, 'w') as archive:\n",
        "        def write_output(image, image_filename):\n",
        "            # NOTE(review): the data is always PNG, whatever extension image_filename carries.\n",
        "            png_bytes = io.BytesIO()\n",
        "            image.save(png_bytes, 'PNG')\n",
        "            encoded = png_bytes.getvalue()\n",
        "            archive.writestr(image_filename, encoded)\n",
        "\n",
        "        yield write_output\n",
        "\n",
        "\n",
        "@contextmanager\n",
        "def csvfile_writer(p: Path):\n",
        "    \"\"\"Create a CSV file at ``p`` and yield a function that appends one row per call.\n",
        "\n",
        "    The header ``filename,count`` is written immediately. The yielded callable is\n",
        "    ``csv.DictWriter.writerow`` and expects a dict with exactly those two keys.\n",
        "    The file is closed when the ``with`` block exits.\n",
        "    \"\"\"\n",
        "    columns = ['filename', 'count']\n",
        "    # newline='' lets the csv module control line endings itself.\n",
        "    with p.open('w', newline='') as handle:\n",
        "        writer = csv.DictWriter(handle, fieldnames=columns)\n",
        "        writer.writeheader()\n",
        "        yield writer.writerow"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "rFXRk-_uuHb_"
      },
      "outputs": [],
      "source": [
        "from tqdm import tqdm\n",
        "import os\n",
        "def process_zipfile(input_zipfile: Path, text: str):\n",
        "    \"\"\"Count the objects described by ``text`` in every image of ``input_zipfile``.\n",
        "\n",
        "    Two files are written next to the input archive:\n",
        "      * ``<stem>_countgd.zip`` -- one heatmap per processed image, stored under the image's name;\n",
        "      * ``<stem>.csv`` -- one ``filename, count`` row per processed image.\n",
        "\n",
        "    Args:\n",
        "        input_zipfile: path to a readable zip archive of images.\n",
        "        text: the prompt naming the object to count; must not be blank.\n",
        "\n",
        "    Returns:\n",
        "        None. An unreadable archive or a blank prompt is logged as an error and\n",
        "        nothing is written.\n",
        "    \"\"\"\n",
        "    if not input_zipfile.exists() or not input_zipfile.is_file() or not os.access(input_zipfile, os.R_OK):\n",
        "        logger.error(f'Cannot open / read zipfile: {input_zipfile}. Please check if it exists')\n",
        "        return\n",
        "\n",
        "    # Treat None and whitespace-only prompts like an empty prompt.\n",
        "    if not text or not text.strip():\n",
        "        logger.error('Please provide the object you would like to count')\n",
        "        return\n",
        "\n",
        "    output_zipfile = input_zipfile.parent / f'{input_zipfile.stem}_countgd.zip'\n",
        "    output_csvfile = input_zipfile.parent / f'{input_zipfile.stem}.csv'\n",
        "\n",
        "    logger.info(f'Writing outputs to {output_zipfile.name} and {output_csvfile.name} in {input_zipfile.parent} folder')\n",
        "    with zipfile_writer(output_zipfile) as add_to_zip, csvfile_writer(output_csvfile) as write_row:\n",
        "        for entry in tqdm(images_from_zipfile(input_zipfile)):\n",
        "            # Entries that could not be read as images may be reported as None;\n",
        "            # skip them instead of failing while unpacking.\n",
        "            if entry is None:\n",
        "                continue\n",
        "            filename, im = entry\n",
        "            boxes, _ = run(im, text)\n",
        "            count, heatmap = get_output(im, boxes)\n",
        "            write_row({'filename': filename, 'count': count})\n",
        "            add_to_zip(heatmap, filename)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "TmqsSxrsuHb_"
      },
      "source": [
        "### Run"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 1,
      "metadata": {
        "id": "ZaN918EkuHb_"
      },
      "outputs": [],
      "source": [
        "# @title ## Parameters { display-mode: \"form\", run: \"auto\" }\n",
        "# @markdown Set the following options to pass to the CountGD Model\n",
        "\n",
        "# @markdown ---\n",
        "# @markdown ### Enter a file path to a zip:\n",
        "zipfile_path = \"/content/drive/My Drive/test_images.zip\" # @param {type:\"string\"}\n",
        "# @markdown\n",
        "# @markdown ### Which object would you like to count?\n",
        "prompt = \"strawberry\" # @param {type:\"string\"}\n",
        "# @markdown ---"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 66,
          "referenced_widgets": [
            "929c22dba6df44c9adedf1c11de22c99",
            "002f1c72fa6a4e229bcc79e74ab88605",
            "85ae470ce6164ade87f27af3fe885e11"
          ]
        },
        "id": "fd-ShBCsuHb_",
        "outputId": "34601ba0-f961-4c20-fa95-e28440f3bb22"
      },
      "outputs": [
        {
          "output_type": "display_data",
          "data": {
            "text/plain": [
              "Button(description='Run', style=ButtonStyle())"
            ],
            "application/vnd.jupyter.widget-view+json": {
              "version_major": 2,
              "version_minor": 0,
              "model_id": "929c22dba6df44c9adedf1c11de22c99"
            }
          },
          "metadata": {}
        },
        {
          "output_type": "stream",
          "name": "stderr",
          "text": [
            "ERROR:root:Cannot open / read zipfile: /content/drive/My Drive/test_images.zip. Please check if it exists\n"
          ]
        }
      ],
      "source": [
        "import ipywidgets as widgets\n",
        "from IPython.display import display\n",
        "button = widgets.Button(description=\"Run\")\n",
        "\n",
        "def on_button_clicked(b):\n",
        "    # Display the message within the output widget.\n",
        "    process_zipfile(Path(zipfile_path), prompt)\n",
        "\n",
        "button.on_click(on_button_clicked)\n",
        "display(button)"
      ]
    }
  ],
  "metadata": {
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.12.7"
    },
    "colab": {
      "provenance": [],
      "collapsed_sections": [
        "9wyM6J2HuHb-",
        "jn061Tl8uHb-",
        "kas5YtyluHb_",
        "IspbBV0XuHb_",
        "gfjraK3vuHb_"
      ],
      "gpuType": "T4"
    },
    "accelerator": "GPU",
    "widgets": {
      "application/vnd.jupyter.widget-state+json": {
        "929c22dba6df44c9adedf1c11de22c99": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "ButtonModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "ButtonModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "ButtonView",
            "button_style": "",
            "description": "Run",
            "disabled": false,
            "icon": "",
            "layout": "IPY_MODEL_002f1c72fa6a4e229bcc79e74ab88605",
            "style": "IPY_MODEL_85ae470ce6164ade87f27af3fe885e11",
            "tooltip": ""
          }
        },
        "002f1c72fa6a4e229bcc79e74ab88605": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "85ae470ce6164ade87f27af3fe885e11": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "ButtonStyleModel",
          "model_module_version": "1.5.0",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "ButtonStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "button_color": null,
            "font_weight": ""
          }
        }
      }
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
